{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 外部代码生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from collections import OrderedDict\n",
    "import numpy as np\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm.relay.build_module import bind_params_by_name\n",
    "from tvm.relay.op.annotation import compiler_begin, compiler_end\n",
    "from tvm import relay, runtime, testing\n",
    "from tvm.contrib import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_external_func_attr(func, compiler, ext_symbol):\n",
    "    \"\"\"Attach the ``Primitive``/``Compiler``/``global_symbol`` attributes to ``func``.\n",
    "\n",
    "    Applies three ``with_attr`` calls in order -- ``Primitive`` (the int32\n",
    "    immediate 1), ``Compiler`` (``compiler``) and ``global_symbol``\n",
    "    (``ext_symbol``) -- and returns the result of the last call.\n",
    "    \"\"\"\n",
    "    attributes = (\n",
    "        (\"Primitive\", tvm.tir.IntImm(\"int32\", 1)),\n",
    "        (\"Compiler\", compiler),\n",
    "        (\"global_symbol\", ext_symbol),\n",
    "    )\n",
    "    for key, value in attributes:\n",
    "        func = func.with_attr(key, value)\n",
    "    return func\n",
    "\n",
    "def update_lib(lib, source_dir):\n",
    "    \"\"\"Export ``lib`` as ``lib.so`` in a temporary directory and load it back.\n",
    "\n",
    "    ``<source_dir>/src/runtime/contrib`` is passed to the compiler as an\n",
    "    include path (``-I``) together with ``-O2 -std=c++17``. Returns the\n",
    "    module produced by ``tvm.runtime.load_module`` on the exported file.\n",
    "    \"\"\"\n",
    "    include_dir = Path(source_dir) / \"src/runtime/contrib\"\n",
    "    compile_options = [\"-O2\", \"-std=c++17\", f\"-I{include_dir}\"]\n",
    "    # Keep the tempdir object bound to a local for the whole call;\n",
    "    # presumably the directory is cleaned up once the object is collected --\n",
    "    # verify against tvm.contrib.utils before inlining this.\n",
    "    scratch_dir = utils.tempdir()\n",
    "    shared_object = scratch_dir.relpath(\"lib.so\")\n",
    "    lib.export_library(shared_object, fcompile=False, options=compile_options)\n",
    "    return tvm.runtime.load_module(shared_object)\n",
    "\n",
    "def check_result(\n",
    "    mod, map_inputs, out_shape, result, tol=1e-5,\n",
    "    target=\"llvm\", device=tvm.cpu(),\n",
    "    source_dir=\"/media/pc/data/board/arria10/lxw/tasks/tvm\"):\n",
    "    \"\"\"Compile ``mod`` with the Relay VM, run it and compare with ``result``.\n",
    "\n",
    "    Args:\n",
    "        mod: module handed to ``relay.vm.compile``.\n",
    "        map_inputs: mapping of input name to array; unpacked as keyword\n",
    "            arguments of ``vm.run``.\n",
    "        out_shape: expected shape of the output; ``None`` skips the shape check.\n",
    "        result: reference array the output is compared against.\n",
    "        tol: used as both ``rtol`` and ``atol`` of the comparison.\n",
    "        target: compilation target.\n",
    "        device: device the virtual machine runs on.\n",
    "        source_dir: forwarded to ``update_lib``.\n",
    "            NOTE(review): the default is an absolute path on one specific\n",
    "            machine; pass your own directory when running elsewhere.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the output shape differs from ``out_shape`` or the\n",
    "            values differ from ``result`` beyond ``tol``.\n",
    "    \"\"\"\n",
    "    with tvm.transform.PassContext(opt_level=3, disabled_pass=[\"AlterOpLayout\"]):\n",
    "        exe = relay.vm.compile(mod, target=target)\n",
    "    code, lib = exe.save()\n",
    "    # Round-trip the library through an exported shared object before loading.\n",
    "    lib = update_lib(lib, source_dir=source_dir)\n",
    "    exe = runtime.vm.Executable.load_exec(code, lib)\n",
    "    vm = runtime.vm.VirtualMachine(exe, device)\n",
    "    out = vm.run(**map_inputs).numpy()\n",
    "    # out_shape used to be accepted but ignored; check it explicitly.\n",
    "    if out_shape is not None:\n",
    "        assert tuple(out.shape) == tuple(out_shape), (\n",
    "            f\"output shape {tuple(out.shape)} != expected {tuple(out_shape)}\"\n",
    "        )\n",
    "    tvm.testing.assert_allclose(out, result, rtol=tol, atol=tol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 多节点子图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "TENSOR_SHAPE = (10, 10)\n",
    "\n",
    "x = relay.var(\"x\", shape=TENSOR_SHAPE)\n",
    "w0, w1, w2, w3, w4, w5, w6, w7 = (\n",
    "    relay.var(f\"w{i}\", shape=TENSOR_SHAPE) for i in range(8)\n",
    ")\n",
    "\n",
    "\n",
    "def _chain_subgraph(index):\n",
    "    \"\"\"Build ``((x + w_0) - w_1) * w_2`` as a function tagged Compiler=\"ccompiler\".\n",
    "\n",
    "    Parameters are named ``x<index>`` and ``w<index>0`` .. ``w<index>2``; the\n",
    "    function's global symbol is ``ccompiler_<index>``.\n",
    "    \"\"\"\n",
    "    data = relay.var(f\"x{index}\", shape=TENSOR_SHAPE)\n",
    "    weights = [relay.var(f\"w{index}{j}\", shape=TENSOR_SHAPE) for j in range(3)]\n",
    "    summed = relay.add(data, weights[0])\n",
    "    diff = relay.subtract(summed, weights[1])\n",
    "    prod = relay.multiply(diff, weights[2])\n",
    "    func = relay.Function([data, *weights], prod)\n",
    "    return set_external_func_attr(func, \"ccompiler\", f\"ccompiler_{index}\")\n",
    "\n",
    "\n",
    "# Two structurally identical subgraphs, each fed its own weight triple.\n",
    "subgraph0 = _chain_subgraph(0)\n",
    "call0 = relay.Call(subgraph0, [x, w0, w1, w2])\n",
    "subgraph1 = _chain_subgraph(1)\n",
    "call1 = relay.Call(subgraph1, [x, w3, w4, w5])\n",
    "\n",
    "# Other parts on TVM\n",
    "q2 = relay.subtract(relay.add(x, w6), w7)\n",
    "\n",
    "r = relay.concatenate((call0, call1, q2), axis=0)\n",
    "f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r)\n",
    "mod = tvm.IRModule()\n",
    "mod[\"main\"] = f\n",
    "mod = relay.transform.InferType()(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w0: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w1: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w2: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: 
bold\">%</span>w3: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w4: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w5: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w6: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w7: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span 
style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">30</span>, <span style=\"color: #008000\">10</span>), float32] {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x0: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w00: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w01: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w02: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: 
#008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_0&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] {\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>x0, <span style=\"color: #A2F; font-weight: bold\">%</span>w00) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w01) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    multiply(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w02) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "  } <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>fn (Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32]) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">5</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x1: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w10: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w11: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w12: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>, Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: 
#008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_1&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] {\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>x1, <span style=\"color: #A2F; font-weight: bold\">%</span>w10) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w11) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "    multiply(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w12) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "  } <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>fn (Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32]) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">6</span> <span style=\"color: #A2F; font-weight: bold\">=</span> add(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>w6) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">7</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>w0, <span style=\"color: #A2F; font-weight: bold\">%</span>w1, <span style=\"color: #A2F; font-weight: bold\">%</span>w2) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">8</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">5</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>w3, <span style=\"color: #A2F; font-weight: bold\">%</span>w4, <span style=\"color: #A2F; font-weight: bold\">%</span>w5) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">9</span> <span style=\"color: #A2F; font-weight: bold\">=</span> subtract(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">6</span>, <span style=\"color: #A2F; font-weight: bold\">%</span>w7) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">10</span> <span style=\"color: #A2F; font-weight: bold\">=</span> (<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">7</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">8</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">9</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>(Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32], Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32]) <span style=\"color: #A2F; font-weight: bold\">*/</span>;\n",
       "  concatenate(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">10</span>) <span style=\"color: #A2F; font-weight: bold\">/*</span> ty<span style=\"color: #A2F; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">30</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #A2F; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_data = np.random.rand(10, 10).astype(\"float32\")\n",
    "w_data = [np.random.rand(10, 10).astype(\"float32\") for _ in range(8)]\n",
    "\n",
    "# Inputs keyed by the parameter names of @main: x, then w0 .. w7.\n",
    "map_inputs = OrderedDict(x=x_data)\n",
    "map_inputs.update((f\"w{i}\", w) for i, w in enumerate(w_data))\n",
    "\n",
    "out_shape = (30, 10)\n",
    "# NumPy reference: the two external branches followed by the TVM branch.\n",
    "expected_blocks = [\n",
    "    ((x_data + w_data[0]) - w_data[1]) * w_data[2],\n",
    "    ((x_data + w_data[3]) - w_data[4]) * w_data[5],\n",
    "    x_data + w_data[6] - w_data[7],\n",
    "]\n",
    "result = np.concatenate(expected_blocks, axis=0)\n",
    "check_result(\n",
    "    mod, map_inputs, out_shape, result,\n",
    "    tol=1e-5, target=\"llvm\", device=tvm.cpu(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 外部 gcc 单个算子"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[10:13:21] /media/pc/data/board/arria10/lxw/tasks/tvm/src/relay/backend/vm/compiler.cc:1199: All lowered functions have been build by BYOC -- generating an empty TVM module\n"
     ]
    }
   ],
   "source": [
    "# Outer graph inputs.\n",
    "x = relay.var(\"x\", shape=(8, 8))\n",
    "y = relay.var(\"y\", shape=(8, 8))\n",
    "\n",
    "# A single add wrapped in a function tagged Compiler=\"ccompiler\".\n",
    "x0 = relay.var(\"x0\", shape=(8, 8))\n",
    "y0 = relay.var(\"y0\", shape=(8, 8))\n",
    "z = relay.add(x0, y0)\n",
    "f = set_external_func_attr(relay.Function([x0, y0], z), \"ccompiler\", \"ccompiler_0\")\n",
    "call = relay.Call(f, [x, y])\n",
    "mod = tvm.IRModule.from_expr(call)\n",
    "\n",
    "x_data, y_data = (np.random.rand(8, 8).astype(\"float32\") for _ in range(2))\n",
    "\n",
    "check_result(\n",
    "    mod,\n",
    "    map_inputs={\"x\": x_data, \"y\": y_data},\n",
    "    out_shape=(8, 8),\n",
    "    result=x_data + y_data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32]) {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x0: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y0: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), float32], Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_0&quot;</span>) {\n",
       "    add(<span style=\"color: #A2F; font-weight: bold\">%</span>x0, <span style=\"color: #A2F; font-weight: bold\">%</span>y0)\n",
       "  };\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), int32], <span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), int32]) {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x0: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), int32], <span style=\"color: #A2F; font-weight: bold\">%</span>y0: Tensor[(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">8</span>), int32], Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_0&quot;</span>) {\n",
       "    add(<span style=\"color: #A2F; font-weight: bold\">%</span>x0, <span style=\"color: #A2F; font-weight: bold\">%</span>y0)\n",
       "  };\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>y)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[10:15:42] /media/pc/data/board/arria10/lxw/tasks/tvm/src/relay/backend/vm/compiler.cc:1199: All lowered functions have been build by BYOC -- generating an empty TVM module\n"
     ]
    }
   ],
   "source": [
    "# Single external add on int32 tensors.\n",
    "x = relay.var(\"x\", shape=(8, 8), dtype=\"int32\")\n",
    "y = relay.var(\"y\", shape=(8, 8), dtype=\"int32\")\n",
    "\n",
    "x0 = relay.var(\"x0\", shape=(8, 8), dtype=\"int32\")\n",
    "y0 = relay.var(\"y0\", shape=(8, 8), dtype=\"int32\")\n",
    "z = x0 + y0\n",
    "f = relay.Function([x0, y0], z)\n",
    "f = set_external_func_attr(f, \"ccompiler\", \"ccompiler_0\")\n",
    "call = relay.Call(f, [x, y])\n",
    "mod = tvm.IRModule.from_expr(call)\n",
    "# np.random.rand() draws from [0, 1), so casting it to int32 would make both\n",
    "# inputs all-zero and the check vacuous (0 + 0 == 0). Draw real integers;\n",
    "# the range keeps the sum far from int32 overflow.\n",
    "x_data = np.random.randint(0, 100, size=(8, 8), dtype=\"int32\")\n",
    "y_data = np.random.randint(0, 100, size=(8, 8), dtype=\"int32\")\n",
    "mod.show()\n",
    "check_result(mod, {\"x\": x_data, \"y\": y_data}, (8, 8), x_data + y_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 外部 gcc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[10:13:42] /media/pc/data/board/arria10/lxw/tasks/tvm/src/relay/backend/vm/compiler.cc:1199: All lowered functions have been build by BYOC -- generating an empty TVM module\n"
     ]
    }
   ],
   "source": [
    "SHAPE_2X2 = (2, 2)\n",
    "x = relay.var(\"x\", shape=SHAPE_2X2)\n",
    "y = relay.var(\"y\", shape=SHAPE_2X2)\n",
    "\n",
    "\n",
    "def _extern_binary_op(op, suffix, symbol):\n",
    "    \"\"\"Wrap the binary Relay op ``op`` in a Compiler=\"ccompiler\" function.\n",
    "\n",
    "    The function's two parameters are named ``x<suffix>`` and ``y<suffix>``\n",
    "    and its global symbol is ``symbol``.\n",
    "    \"\"\"\n",
    "    lhs = relay.var(f\"x{suffix}\", shape=SHAPE_2X2)\n",
    "    rhs = relay.var(f\"y{suffix}\", shape=SHAPE_2X2)\n",
    "    func = relay.Function([lhs, rhs], op(lhs, rhs))\n",
    "    return set_external_func_attr(func, \"ccompiler\", symbol)\n",
    "\n",
    "\n",
    "# subgraph for mul: y * y\n",
    "mul = _extern_binary_op(relay.multiply, 0, \"ccompiler_2\")\n",
    "call_mul = relay.Call(mul, [y, y])\n",
    "\n",
    "# subgraph for add: x + x\n",
    "add = _extern_binary_op(relay.add, 1, \"ccompiler_1\")\n",
    "call_add = relay.Call(add, [x, x])\n",
    "\n",
    "# subgraph for sub: (y * y) - (x + x)\n",
    "sub = _extern_binary_op(relay.subtract, 2, \"ccompiler_0\")\n",
    "call_sub = relay.Call(sub, [call_mul, call_add])\n",
    "mod = tvm.IRModule.from_expr(call_sub)\n",
    "\n",
    "x_data, y_data = (np.random.rand(2, 2).astype(\"float32\") for _ in range(2))\n",
    "\n",
    "# \"y\" is listed before \"x\" on purpose; the key order of the original is kept.\n",
    "inputs = OrderedDict([(\"y\", y_data), (\"x\", x_data)])\n",
    "\n",
    "check_result(mod, inputs, (2, 2), (y_data * y_data) - (x_data + x_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #A2F\">@main</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32]) {\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x0: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y0: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_2&quot;</span>) {\n",
       "    multiply(<span style=\"color: #A2F; font-weight: bold\">%</span>x0, <span style=\"color: #A2F; font-weight: bold\">%</span>y0)\n",
       "  };\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x1: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y1: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_1&quot;</span>) {\n",
       "    add(<span style=\"color: #A2F; font-weight: bold\">%</span>x1, <span style=\"color: #A2F; font-weight: bold\">%</span>y1)\n",
       "  };\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>y, <span style=\"color: #A2F; font-weight: bold\">%</span>y);\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span> <span style=\"color: #A2F; font-weight: bold\">=</span> <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">1</span>(<span style=\"color: #A2F; font-weight: bold\">%</span>x, <span style=\"color: #A2F; font-weight: bold\">%</span>x);\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span> <span style=\"color: #A2F; font-weight: bold\">=</span> fn (<span style=\"color: #A2F; font-weight: bold\">%</span>x2: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], <span style=\"color: #A2F; font-weight: bold\">%</span>y2: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), float32], Primitive<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, Compiler<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler&quot;</span>, global_symbol<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;ccompiler_0&quot;</span>) {\n",
       "    subtract(<span style=\"color: #A2F; font-weight: bold\">%</span>x2, <span style=\"color: #A2F; font-weight: bold\">%</span>y2)\n",
       "  };\n",
       "  <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">4</span>(<span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">2</span>, <span style=\"color: #A2F; font-weight: bold\">%</span><span style=\"color: #008000\">3</span>)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display the Relay IR of `mod`, the module assembled in the previous cell.\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aix",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
